// Copyright 1998-2016 Glenn McIntosh
// licensed under the GNU General Public Licence version 3
#pragma once

/** @file ntt.h
	number theoretic transform functions
	*/
// include files
#include <cstddef>
#include <cstdint>
#include <cassert>
#include <vector>

namespace math
{
/** base type used to hold part of multiprecision int */
using Base = uint32_t;

/** base type used for intermediate results, able to hold Base squared */
using Base2 = uint64_t;

/** multiprecision integer container */
using Mpint = std::vector<Base>;

/** number theoretic transform a vector (DIF, no bit reversal permutation)
	The transform works modulo order*scale+1. Because the bit reversal
	permutation is skipped, the result is left in bit-reversed index order.
	Elements are expected to be already reduced (less than the modulus).
	@param order is the order of the ring (a power of two, at least x.size())
	@param scale is (modulus-1)/order
	@param root is the root of the ring
	@param x is the vector to be transformed in place (length a power of two)
	*/
template<Base order = 1u<<30, Base scale = 3u, Base root = 125u>
void ntt(Mpint &x);

/** inverse number theoretic transform a vector (DIT, no bit reversal permutation)
	Undoes ntt(): takes input in bit-reversed index order and leaves the
	result, already divided by the length, in natural order.
	Elements are expected to be already reduced (less than the modulus).
	@param order is the order of the ring (a power of two, at least x.size())
	@param scale is (modulus-1)/order
	@param root is the root of the ring
	@param x is the vector to be transformed in place (length a power of two)
	*/
template<Base order = 1u<<30, Base scale = 3u, Base root = 125u>
void intt(Mpint &x);

/** convolve vectors (cyclic convolution modulo order*scale+1)
	The result replaces x; y is overwritten with its forward transform.
	@param order is the order of the ring
	@param scale is (modulus-1)/order
	@param root is the root of the ring
	@param x is the data to be convolved (length a power of two)
	@param y is the data to be convolved with (same length)
	*/
template<Base order = 1u<<30, Base scale = 3u, Base root = 125u>
void convolve(Mpint &x, Mpint &y);

// implementation helpers; an inline (rather than unnamed) namespace keeps the
// names reachable as math::name while giving them a single identity across
// translation units
inline namespace ntt_detail
{
/** number of bits in Base (half the width of Base2) */
constexpr int basebits = sizeof(Base2)*8/2;

/** square of a value */
template<typename T> constexpr
T sqr(T x) {return x*x;}

/** modular exponentiation by repeated squaring
	@param modulus is the modulus of the ring
	@param root is the value to be raised to a power
	@param p is the exponent
	@return root to the power p, modulo modulus
	*/
template<Base modulus, Base root>
constexpr Base powermod(Base p)
{
	return static_cast<Base>(
		p ?
			p&1u ?
				(sqr(Base2{powermod<modulus, root>(p>>1)}) % modulus) * root % modulus
			:
				sqr(Base2{powermod<modulus, root>(p>>1)}) % modulus
		:
			Base2{1u}
		);
}

/** branch-free modular addition
	@param a is the first operand (less than modulus)
	@param b is the second operand (less than modulus)
	@return (a+b) mod modulus, always less than modulus
	*/
template<Base modulus>
constexpr Base addmod(Base2 a, Base2 b)
{
	// modulus-1-(a+b) borrows exactly when a+b >= modulus; the borrow fills
	// the upper half of the word with ones, which selects modulus as the
	// amount to subtract
	return static_cast<Base>((a+b) - (((Base2{modulus}-1u-(a+b)) >> basebits) & modulus));
}

/** branch-free modular subtraction
	@param a is the first operand (less than modulus)
	@param b is the second operand (less than modulus)
	@return (a-b) mod modulus, always less than modulus
	*/
template<Base modulus>
constexpr Base submod(Base2 a, Base2 b)
{
	// a-b borrows exactly when a < b; the borrow selects modulus to add back
	return static_cast<Base>((a-b) + (((a-b) >> basebits) & modulus));
}

/** count trailing zero bits (portable; returns 0 for an argument of 0)
	@param m is the value to examine
	@return the number of zero bits below the lowest set bit
	*/
constexpr int ctz(int m)
{
	return (m == 0 || (m&1)) ? 0 : 1 + ctz(static_cast<int>(static_cast<unsigned>(m) >> 1));
}
}

// number theoretic transform (DIF, no bit reversal permutation)
// (default template arguments are given on the declaration above)
template<Base order, Base scale, Base root>
void ntt(Mpint &x)
{
	constexpr Base modulus = order*scale+1u;
	const std::size_t n = x.size();
	assert((n & (n-1)) == 0);
	assert(n <= order); // otherwise the ring has no root of unity of order n
	if (n <= 1) return;

	// decimate in frequency
	for (std::size_t m = n/2; m; m >>= 1)
	{
		// primitive 2m-th root of unity
		const Base2 cs1 = powermod<modulus, root>(static_cast<Base>(order/2u/m));
		for (std::size_t j = 0; j < n; j += 2*m)
		{
			Base2 cs = 1;
			for (std::size_t w = j; w < j+m; ++w)
			{
				const Base2 x0 = x[w];
				const Base2 x1 = x[w+m];
				x[w] = addmod<modulus>(x0, x1);
				x[w+m] = static_cast<Base>(Base2{submod<modulus>(x0, x1)} * cs % modulus);
				cs = cs*cs1 % modulus;
			}
		}
	}
}

// number theoretic transform (DIT, no bit reversal permutation)
// (default template arguments are given on the declaration above)
template<Base order, Base scale, Base root>
void intt(Mpint &x)
{
	constexpr Base modulus = order*scale+1u;
	const std::size_t n = x.size();
	assert((n & (n-1)) == 0);
	assert(n <= order); // otherwise the ring has no root of unity of order n
	if (n <= 1) return;

	// divide by n: modulus is 1 mod n, so adding (-v mod n) copies of the
	// modulus makes the value an exact multiple of n without changing its
	// residue, and the shift is then an exact division
	const int shift = ctz(static_cast<int>(n));
	const Base mask = static_cast<Base>(n-1);
	for (Base &v : x)
	{
		const Base k = (0u - v) & mask;
		v = static_cast<Base>((Base2{v} + Base2{k}*modulus) >> shift);
	}

	// decimate in time
	for (std::size_t m = 1; m < n; m <<= 1)
	{
		// inverse of the primitive 2m-th root of unity used by ntt()
		const Base2 cs1 = powermod<modulus, root>(static_cast<Base>(order - order/2u/m));
		for (std::size_t j = 0; j < n; j += 2*m)
		{
			Base2 cs = 1;
			for (std::size_t w = j; w < j+m; ++w)
			{
				const Base2 x0 = x[w];
				const Base2 x1 = x[w+m]*cs % modulus;
				x[w] = addmod<modulus>(x0, x1);
				x[w+m] = submod<modulus>(x0, x1);
				cs = cs*cs1 % modulus;
			}
		}
	}
}

// number theoretic convolution
// (default template arguments are given on the declaration above)
template<Base order, Base scale, Base root>
void convolve(Mpint &x, Mpint &y)
{
	constexpr Base modulus = order*scale+1u; // prime
	const std::size_t n = x.size();
	assert((n & (n-1)) == 0);
	assert(n == y.size());

	// forward transform (both results are in bit-reversed order, which does
	// not matter for the elementwise product)
	ntt<order, scale, root>(x);
	ntt<order, scale, root>(y);

	// multiply
	for (std::size_t i = 0; i < n; ++i)
		x[i] = static_cast<Base>(Base2{x[i]}*y[i] % modulus);

	// inverse transform
	intt<order, scale, root>(x);
}
}
